{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "57a09e0b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizerFast(name_or_path='hfl/rbt3', vocab_size=21128, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第11章/加载编码器\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained('hfl/rbt3')\n",
    "\n",
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cd4a0782",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CLS] 海 钓 比 赛 地 点 在 厦 门 与 金 门 之 间 的 海 域 。 [SEP]\n",
      "[CLS] 这 座 依 山 傍 水 的 博 物 馆 由 国 内 一 流 的 设 计 [SEP]\n",
      "input_ids tf.Tensor(\n",
      "[[ 101 3862 7157 3683 6612 1765 4157 1762 1336 7305  680 7032 7305  722\n",
      "  7313 4638 3862 1818  511  102]\n",
      " [ 101 6821 2429  898 2255  988 3717 4638 1300 4289 7667 4507 1744 1079\n",
      "   671 3837 4638 6392 6369  102]], shape=(2, 20), dtype=int32)\n",
      "token_type_ids tf.Tensor(\n",
      "[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
      " [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]], shape=(2, 20), dtype=int32)\n",
      "attention_mask tf.Tensor(\n",
      "[[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n",
      " [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]], shape=(2, 20), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "#第11章/编码测试\n",
    "out = tokenizer.batch_encode_plus(\n",
    "    [[\n",
    "        '海', '钓', '比', '赛', '地', '点', '在', '厦', '门', '与', '金', '门', '之', '间',\n",
    "        '的', '海', '域', '。'\n",
    "    ],\n",
    "     [\n",
    "         '这', '座', '依', '山', '傍', '水', '的', '博', '物', '馆', '由', '国', '内', '一',\n",
    "         '流', '的', '设', '计', '师', '主', '持', '设', '计', '。'\n",
    "     ]],\n",
    "    truncation=True,\n",
    "    padding=True,\n",
    "    return_tensors='tf',\n",
    "    max_length=20,\n",
    "    is_split_into_words=True)\n",
    "\n",
    "#还原编码为句子\n",
    "print(tokenizer.decode(out['input_ids'][0]))\n",
    "print(tokenizer.decode(out['input_ids'][1]))\n",
    "\n",
    "for k, v in out.items():\n",
    "    print(k, v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c133bdbb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['id', 'tokens', 'ner_tags'],\n",
       "    num_rows: 20865\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第11章/获取数据集\n",
    "from datasets import load_dataset, load_from_disk\n",
    "\n",
    "\n",
    "def get_dataset(split):\n",
    "    #在线加载数据集\n",
    "    #dataset = load_dataset(path='peoples_daily_ner', split=split)\n",
    "\n",
    "    #离线加载数据集\n",
    "    dataset = load_from_disk(dataset_path='./data/peoples_daily_ner')[split]\n",
    "\n",
    "    #打乱顺序\n",
    "    dataset = dataset.shuffle()\n",
    "\n",
    "    #dataset.features['ner_tags'].feature.num_classes\n",
    "    #7\n",
    "\n",
    "    #dataset.features['ner_tags'].feature.names\n",
    "    #['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']\n",
    "\n",
    "    return dataset\n",
    "\n",
    "\n",
    "dataset = get_dataset('train')\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "590f5310",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "#第11章/定义数据加载函数\n",
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "def get_batch_data(dataset, idx, batch_size):\n",
    "    idx_from = idx * batch_size\n",
    "    idx_to = idx_from + batch_size\n",
    "\n",
    "    if idx_to > dataset.num_rows:\n",
    "        return None, None\n",
    "\n",
    "    data = dataset[idx_from:idx_to]\n",
    "\n",
    "    #编码数据\n",
    "    inputs = tokenizer.batch_encode_plus(data['tokens'],\n",
    "                                         truncation=True,\n",
    "                                         padding=True,\n",
    "                                         return_tensors='tf',\n",
    "                                         max_length=512,\n",
    "                                         is_split_into_words=True)\n",
    "\n",
    "    labels = data['ner_tags']\n",
    "\n",
    "    #求一批数据中最长的句子长度\n",
    "    lens = inputs['input_ids'].shape[1]\n",
    "\n",
    "    #在labels的头尾补充7，把所有的labels补充成统一的长度\n",
    "    for i in range(len(labels)):\n",
    "        labels[i] = [7] + labels[i]\n",
    "        labels[i] += [7] * lens\n",
    "        labels[i] = labels[i][:lens]\n",
    "\n",
    "    labels = tf.constant(labels, dtype=tf.int32)\n",
    "\n",
    "    return inputs, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "10ffda93",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids (16, 112)\n",
      "token_type_ids (16, 112)\n",
      "attention_mask (16, 112)\n",
      "labels (16, 112)\n"
     ]
    }
   ],
   "source": [
    "#第11章/查看数据样例\n",
    "inputs, labels = get_batch_data(dataset, 0, 16)\n",
    "\n",
    "for k, v in inputs.items():\n",
    "    print(k, v.shape)\n",
    "\n",
    "print('labels', labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0f9eaff4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "All model checkpoint layers were used when initializing TFBertModel.\n",
      "\n",
      "All the layers of TFBertModel were initialized from the model checkpoint at hfl/rbt3.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"tf_bert_model\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "bert (TFBertMainLayer)       multiple                  38476800  \n",
      "=================================================================\n",
      "Total params: 38,476,800\n",
      "Trainable params: 38,476,800\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "#第11章/加载预训练模型\n",
    "from transformers import TFAutoModel\n",
    "\n",
    "pretrained = TFAutoModel.from_pretrained('hfl/rbt3')\n",
    "\n",
    "#查看模型概述\n",
    "pretrained.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f4b9a103",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([16, 112, 768])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第11章/模型试算\n",
    "#[b, lens] -> [b, lens, 768]\n",
    "pretrained(**inputs).last_hidden_state.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fa466940",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([16, 112, 8])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第11章/定义下游模型\n",
    "class Model(tf.keras.Model):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        #标识当前模型是否处于tuning模式\n",
    "        self.tuning = False\n",
    "        #当处于tuning模式时backbone应该属于当前模型的一部分，否则该变量为空\n",
    "        self.pretrained = None\n",
    "\n",
    "        #当前模型的神经网络层\n",
    "        self.rnn = tf.keras.layers.GRU(units=768, return_sequences=True)\n",
    "        self.fc = tf.keras.layers.Dense(units=8, activation=tf.nn.softmax)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        #根据当前模型是否处于tuning模式而使用外部backbone或内部backbone计算\n",
    "        if self.tuning:\n",
    "            out = self.pretrained(**inputs).last_hidden_state\n",
    "        else:\n",
    "            out = pretrained(**inputs).last_hidden_state\n",
    "\n",
    "        #backbone抽取的特征输入rnn网络进一步抽取特征\n",
    "        out = self.rnn(out)\n",
    "\n",
    "        #rnn网络抽取的特征最后输入fc神经网络分类\n",
    "        out = self.fc(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "    #切换下游任务模型的的tuning模式\n",
    "    def fine_tuning(self, tuning):\n",
    "        self.tuning = tuning\n",
    "        #tuning模式时，训练backbone的参数\n",
    "        if tuning:\n",
    "            self.pretrained = pretrained\n",
    "        #非tuning模式时，不训练backbone的参数\n",
    "        else:\n",
    "            self.pretrained = None\n",
    "\n",
    "\n",
    "model = Model()\n",
    "\n",
    "model(inputs).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f16c9703",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<tf.Tensor: shape=(6, 8), dtype=float32, numpy=\n",
       " array([[ 1.3714149e+00, -1.5095697e-04, -7.1241909e-01, -8.6933529e-01,\n",
       "         -5.9476590e-01, -8.1611431e-01, -2.4342534e-01,  1.1834431e+00],\n",
       "        [-6.7946464e-01, -8.9938432e-01,  1.8360443e-01,  1.0472615e+00,\n",
       "         -6.2940359e-01, -2.0146336e-01,  6.4231116e-01, -3.0778365e-02],\n",
       "        [ 9.2144525e-01, -3.3907911e-01, -9.2427254e-01,  1.5306507e+00,\n",
       "         -9.7463369e-01,  1.0550396e+00,  7.0500964e-01,  7.7894813e-01],\n",
       "        [ 4.3940777e-01,  4.8754442e-01, -5.4782212e-01,  7.3784602e-01,\n",
       "          5.3179342e-01,  5.2382642e-01, -9.4239825e-01,  5.2462798e-01],\n",
       "        [-6.3888496e-01, -7.4003971e-01,  8.7833917e-01,  9.1179144e-01,\n",
       "          9.7684518e-02, -1.0721865e+00,  1.2910798e+00, -4.8718882e-01],\n",
       "        [ 4.3508747e-01, -1.5385412e+00,  6.7557418e-01, -1.6573043e+00,\n",
       "         -1.3294674e+00,  4.0037945e-01, -5.7826185e-01, -8.4279853e-01]],\n",
       "       dtype=float32)>,\n",
       " <tf.Tensor: shape=(6,), dtype=float32, numpy=array([1., 1., 1., 1., 1., 1.], dtype=float32)>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第11章/对计算结果和label变形,并且移除pad\n",
    "def reshape_and_remove_pad(outs, labels, attention_mask):\n",
    "    #变形,便于计算loss\n",
    "    #[b, lens, 8] -> [b*lens, 8]\n",
    "    #[b, lens] -> [b*lens]\n",
    "    outs = tf.reshape(outs, [-1, 8])\n",
    "    labels = tf.reshape(labels, [-1])\n",
    "\n",
    "    #忽略对pad的计算结果\n",
    "    #[b, lens] -> [b*lens - pad]\n",
    "    select = tf.reshape(attention_mask, [-1]) == 1\n",
    "    outs = outs[select]\n",
    "    labels = labels[select]\n",
    "\n",
    "    return outs, labels\n",
    "\n",
    "\n",
    "reshape_and_remove_pad(tf.random.normal([2, 3, 8]), tf.ones([2, 3]),\n",
    "                       tf.ones([2, 3]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "bca276cb",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 16, 2, 16)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第11章/获取正确数量和总数\n",
    "def get_correct_and_total_count(outs, labels):\n",
    "    #[b*lens, 8] -> [b*lens]\n",
    "    outs = tf.argmax(outs, axis=1, output_type=tf.int32)\n",
    "\n",
    "    correct = tf.cast(outs == labels, dtype=tf.int32)\n",
    "    correct = int(tf.reduce_sum(correct))\n",
    "\n",
    "    total = len(labels)\n",
    "\n",
    "    #计算除了0以外元素的正确率,因为0太多了,包括的话,正确率很容易虚高\n",
    "    select = labels != 0\n",
    "    outs = outs[select]\n",
    "    labels = labels[select]\n",
    "\n",
    "    correct_content = tf.cast(outs == labels, dtype=tf.int32)\n",
    "    correct_content = int(tf.reduce_sum(correct_content))\n",
    "\n",
    "    total_content = len(labels)\n",
    "\n",
    "    return correct, total, correct_content, total_content\n",
    "\n",
    "\n",
    "get_correct_and_total_count(tf.random.normal([16, 8]),\n",
    "                            tf.ones([16], dtype=tf.int32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "07be1427",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "#第11章/训练\n",
    "from transformers import create_optimizer\n",
    "\n",
    "\n",
    "def train(epochs):\n",
    "    #创建优化器和学习率衰减工具\n",
    "    optimizer, schedule = create_optimizer(\n",
    "        #如果模型是tuning模式则使用更小的学习率\n",
    "        init_lr=2e-5 if model.tuning else 5e-4,\n",
    "        num_warmup_steps=0,\n",
    "        #统计总steps\n",
    "        num_train_steps=(dataset.num_rows // 16) * epochs)\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        i = 0\n",
    "        while True:\n",
    "            #取1个批次的数据\n",
    "            inputs, labels = get_batch_data(dataset, i, 16)\n",
    "            #如果没有取到数据，则说明数据已经遍历结束\n",
    "            if inputs == None:\n",
    "                break\n",
    "\n",
    "            #记录梯度变化\n",
    "            with tf.GradientTape() as tape:\n",
    "                #模型计算\n",
    "                #[b, lens] -> [b, lens, 8]\n",
    "                outs = model(inputs)\n",
    "\n",
    "                #对outs和label变形,并且移除pad\n",
    "                #outs -> [b, lens, 8] -> [c, 8]\n",
    "                #labels -> [b, lens] -> [c]\n",
    "                outs, labels = reshape_and_remove_pad(outs, labels,\n",
    "                                                      inputs['attention_mask'])\n",
    "\n",
    "                #计算loss\n",
    "                loss = tf.losses.categorical_crossentropy(\n",
    "                    y_true=tf.one_hot(labels, depth=8),\n",
    "                    y_pred=outs,\n",
    "                    from_logits=False,\n",
    "                    axis=1,\n",
    "                )\n",
    "                loss = tf.reduce_mean(loss)\n",
    "\n",
    "            #根据loss计算参数梯度\n",
    "            grads = tape.gradient(loss, model.trainable_variables)\n",
    "\n",
    "            #根据梯度更新参数\n",
    "            optimizer.apply_gradients(\n",
    "                (grad, var)\n",
    "                for (grad, var) in zip(grads, model.trainable_variables)\n",
    "                if grad is not None)\n",
    "\n",
    "            #衰减学习率\n",
    "            schedule(1)\n",
    "\n",
    "            if i % 50 == 0:\n",
    "                counts = get_correct_and_total_count(outs, labels)\n",
    "                accuracy = counts[0] / counts[1]\n",
    "                accuracy_content = counts[2] / counts[3]\n",
    "                lr = float(optimizer.lr(optimizer.iterations))\n",
    "\n",
    "                print(epoch, i, float(loss), lr, accuracy, accuracy_content)\n",
    "\n",
    "            i += 1\n",
    "\n",
    "        #保存模型参数\n",
    "        model.save_weights('model/tf_parameters/中文命名实体识别')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "23eee87d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "354.9704\n"
     ]
    }
   ],
   "source": [
    "#第11章/两段式训练第1步，训练下游任务模型\n",
    "model.fine_tuning(False)\n",
    "print(sum([int(tf.size(i)) for i in model.trainable_variables]) / 10000)\n",
    "#train(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b9fff3e7",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4202.6504\n"
     ]
    }
   ],
   "source": [
    "#第11章/两段式训练第2步，同时训练下游任务模型和预训练模型\n",
    "model.fine_tuning(True)\n",
    "print(sum([int(tf.size(i)) for i in model.trainable_variables]) / 10000)\n",
    "#train(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6183bfd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "0.9889047294365586 0.9555555555555556\n"
     ]
    }
   ],
   "source": [
    "#第11章/测试\n",
    "def test():\n",
    "    #加载训练完的模型参数\n",
    "    model.load_weights('model/tf_parameters/中文命名实体识别')\n",
    "\n",
    "    #测试数据集\n",
    "    dataset_test = get_dataset('test')\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    correct_content = 0\n",
    "    total_content = 0\n",
    "\n",
    "    #测试5个批次即可\n",
    "    for i in range(5):\n",
    "        print(i)\n",
    "        inputs, labels = get_batch_data(dataset_test, i, 128)\n",
    "\n",
    "        #计算\n",
    "        #[b, lens] -> [b, lens, 8] -> [b, lens]\n",
    "        outs = model(inputs)\n",
    "\n",
    "        #对outs和label变形,并且移除pad\n",
    "        #outs -> [b, lens, 8] -> [c, 8]\n",
    "        #labels -> [b, lens] -> [c]\n",
    "        outs, labels = reshape_and_remove_pad(outs, labels,\n",
    "                                              inputs['attention_mask'])\n",
    "\n",
    "        #统计正确数量\n",
    "        counts = get_correct_and_total_count(outs, labels)\n",
    "        correct += counts[0]\n",
    "        total += counts[1]\n",
    "        correct_content += counts[2]\n",
    "        total_content += counts[3]\n",
    "\n",
    "    print(correct / total, correct_content / total_content)\n",
    "\n",
    "\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "de33ffd4",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CLS]当这一切辛勤劳作的成果不被承认或竟被剥夺，他便用笔控诉这种剥夺，于是有了他诗中的那种抗争的激情。[SEP]\n",
      "[CLS]7················································[SEP]7\n",
      "[CLS]7················································[SEP]7\n",
      "==========================\n",
      "[CLS]本报讯6月20日，红双喜中国乒乓球俱乐部甲级联赛大战7场，掀起一个小高潮。[SEP]\n",
      "[CLS]7·········红3双4喜4中5国6·······················[SEP]7\n",
      "[CLS]7·········红3双4·中5国4·······················[SEP]7\n",
      "==========================\n",
      "[CLS]粉煤灰用于煤气生产密封效果好本报讯黑龙江省牡丹江市煤气公司把粉煤灰和粘土按一定比例配制成焦炉密封材料，在国内同行业中开了先河。[SEP]\n",
      "[CLS]7·················黑5龙6江6省6牡3丹4江4市4煤4气4公4司4··································[SEP]7\n",
      "[CLS]7·················黑5龙6江6省6牡3丹4江4市4煤4气4公4司4··································[SEP]7\n",
      "==========================\n",
      "[CLS]求同存异：抗日战争时期，在同国民党的统一战线中，国民党许多不利于团结和动员人民抗战的错误政策，只要不是根本危及团结抗战的急迫问题，可以暂时求同存异，加以等待。[SEP]\n",
      "[CLS]7······日5·······国3民4党4·······国3民4党4····················································[SEP]7\n",
      "[CLS]7······日5·······国3民4党4·······国3民4党4····················································[SEP]7\n",
      "==========================\n",
      "[CLS]马克思主义的历史观并不是由什么人发明出来，而后从外部强加给历史的僵化的原则；它本身正是从无数的历史现象中抽象出来的对于历史发展的规律性的认识。[SEP]\n",
      "[CLS]7马1克2思2····································································[SEP]7\n",
      "[CLS]7马1克2思2····································································[SEP]7\n",
      "==========================\n",
      "[CLS]他猛抬头，怔一怔，两食指抵住两太阳穴对我说：[UNK]对不起，我这时候脑袋里全是古人的名字。[UNK][SEP]\n",
      "[CLS]7···········································[SEP]7\n",
      "[CLS]7···············太5阳6穴6·························[SEP]7\n",
      "==========================\n",
      "[CLS]董建华说，1997年7月1日，香港在举世注目下，昂然迈进一个新纪元，成为中华人民共和国的特别行政区。[SEP]\n",
      "[CLS]7董1建2华2············香5港6···················中5华6人6民6共6和6国6·······[SEP]7\n",
      "[CLS]7董1建2华2············香5港6···················中5华6人6民6共6和6国6·······[SEP]7\n",
      "==========================\n",
      "[CLS]美国前总统国家安全事务顾问罗伯特·麦克法兰日前对本社记者发表谈话，认为江泽民主席去年访美是一个分水岭，从此美中两国关系不断改善。[SEP]\n",
      "[CLS]7美5国6···········罗1伯2特2·2麦2克2法2兰2··············江1泽2民2·····美5·········美5中5·········[SEP]7\n",
      "[CLS]7美5国6···········罗1伯2特2·2麦2克2法2兰2··············江1泽2民2·····美5·········美5中5·········[SEP]7\n",
      "==========================\n",
      "[CLS]云居寺石经是从隋代高僧静琬开始至明末清初1000多年间刻出的，计14278块，共刻佛经1122部3572卷。[SEP]\n",
      "[CLS]7云5居6寺6········静1琬2·········································[SEP]7\n",
      "[CLS]7云5居6寺6········静1琬2·········································[SEP]7\n",
      "==========================\n",
      "[CLS]评选活动由《健康顾问》杂志主办，得到了深圳花仙子艺术设计有限公司的大力支持。[SEP]\n",
      "[CLS]7·····《3健4康4顾4问4》4杂4志4······深3圳4花4仙4子4艺4术4设4计4有4限4公4司4······[SEP]7\n",
      "[CLS]7·····《3健4康4顾4问4》4杂4志4······深3圳4花4仙4子4艺4术4设4计4有4限4公4司4······[SEP]7\n",
      "==========================\n",
      "[CLS]我们常常一早从桥儿沟鲁艺出发，通过飞机场，过延河到文化俱乐部；往往演出到深夜才又经过飞机场，踏着寂静和曲折的山路返回到鲁艺。[SEP]\n",
      "[CLS]7·······桥5儿6沟6鲁3艺4··········延5河6···································鲁3艺4·[SEP]7\n",
      "[CLS]7·······桥5儿6沟6鲁4艺4··········延5河4··化4··部4·····························鲁5艺4·[SEP]7\n",
      "==========================\n",
      "[CLS]经协商，人民出版社以它的副牌东方出版社的名义购买了这部书的版权，从今年2月起约请专人翻译。[SEP]\n",
      "[CLS]7····人3民4出4版4社4·····东3方4出4版4社4··························[SEP]7\n",
      "[CLS]7····人3民4出4版4社4·····东3方4出4版4社4··························[SEP]7\n",
      "==========================\n",
      "[CLS]维克斯集团总裁柯林·钱德勒在股东大会后对新闻界表示，对维克斯集团、劳斯莱斯汽车公司以及其员工来说，这是一笔好交易。[SEP]\n",
      "[CLS]7维3克4斯4集4团4··柯1林2·2钱2德2勒2··············维3克4斯4集4团4·劳3斯4莱4斯4汽4车4公4司4················[SEP]7\n",
      "[CLS]7维3克4斯4集4团4··柯1林2·2钱2德2勒2··············维3克4斯4集4团4·劳3斯4莱4斯4汽4车4公4司4················[SEP]7\n",
      "==========================\n",
      "[CLS]党的十五大报告中在阐述建设有中国特色社会主义时，强调了精神文明建设与文化建设之间的关系问题。[SEP]\n",
      "[CLS]7··十3五4大4·········中5国6······························[SEP]7\n",
      "[CLS]7··十3五4大4·········中5国6······························[SEP]7\n",
      "==========================\n",
      "[CLS]他说，巴基斯坦没有任何侵略别国的意图，但在事关国家安全的问题上决不可能作出任何妥协。[SEP]\n",
      "[CLS]7···巴5基6斯6坦6···································[SEP]7\n",
      "[CLS]7···巴5基6斯6坦6···································[SEP]7\n",
      "==========================\n",
      "[CLS]本报讯国际棋联世界冠军卡尔波夫将于5月中下旬再度来华访问表演。[SEP]\n",
      "[CLS]7···国3际4棋4联4····卡1尔2波2夫2··········华5·····[SEP]7\n",
      "[CLS]7···国3际4棋4联4····卡1尔2波2夫2··········华5·····[SEP]7\n",
      "==========================\n",
      "[CLS]像具有张裕公司这样的生产能力的公司一家就足可以[UNK]包圆儿[UNK]。[SEP]\n",
      "[CLS]7···张3裕4公4司4······················[SEP]7\n",
      "[CLS]7···张3裕4公4司4······················[SEP]7\n",
      "==========================\n",
      "[CLS]据了解，可能被窃走的这些软件用于协调与10多枚卫星相联系的军事全球定位系统，而这一系统是海湾战争以来美国军方发展的一项重要的军事技术，专门用于导弹准确地打击目标和确定军人自己所处的地理位置。[SEP]\n",
      "[CLS]7············································海5湾6····美5国6···········································[SEP]7\n",
      "[CLS]7·············································湾6····美5国6···········································[SEP]7\n",
      "==========================\n",
      "[CLS]位于准噶尔盆地东部的新疆野马饲养繁殖中心主任原洪说，截至目前，该中心从国外引进的18匹普氏野马，已成功繁殖野马81匹，繁殖成活率达84％，居世界首位。[SEP]\n",
      "[CLS]7··准5噶6尔6·····新5疆6野6马6饲6养6繁6殖6中6心6··原1洪2···················································[SEP]7\n",
      "[CLS]7··准5噶6尔6盆6····新5疆6野4马4饲4养4繁4殖4中4心4··原1洪2···················································[SEP]7\n",
      "==========================\n",
      "[CLS]信息差主要体现在预期质量和网络意识方面。[SEP]\n",
      "[CLS]7····················[SEP]7\n",
      "[CLS]7····················[SEP]7\n",
      "==========================\n",
      "[CLS]但是由于今天的失误，他的总体成绩会受到影响，可能会比预期的结果差些。[SEP]\n",
      "[CLS]7··································[SEP]7\n",
      "[CLS]7··································[SEP]7\n",
      "==========================\n",
      "[CLS]因此，要更加自觉更加刻苦地学好邓小平党建理论，真正弄懂弄通这一理论的精神实质，掌握其基本的立场、观点、方法，以及研究新情况、解决新问题、开拓新境界的科学态度、创造精神和革命风格，牢牢把握解放思想、实事求是这个精髓。[SEP]\n",
      "[CLS]7···············邓1小2平2·························································································[SEP]7\n",
      "[CLS]7···············邓1小2平2·························································································[SEP]7\n",
      "==========================\n",
      "[CLS]李岚清副总理1995年访问非洲时说：[UNK]首先要改变那种轻视非洲市场、不愿意下功夫去挤占市场的倾向，推动中国企业去办一些投资小、见效快、效益好的独资或合资企业，如家用电器、小汽车、卡车和小型拖拉机、日用消费品等。[UNK][SEP]\n",
      "[CLS]7李1岚2清2··········非5洲6·············非5洲6····················中5国6·····················································[SEP]7\n",
      "[CLS]7李1岚2清2··········非5洲6·············非5洲6····················中5国6·····················································[SEP]7\n",
      "==========================\n",
      "[CLS]此外，他还强化监督机制建设，在系统内设立了97名廉政监察员，对外聘请了445人为廉政督察员，向全市6万多个纳税户发放监督表，请社会公开评议税风。[SEP]\n",
      "[CLS]7········································································[SEP]7\n",
      "[CLS]7········································································[SEP]7\n",
      "==========================\n",
      "[CLS]他们是守门员塔法雷尔、卡洛斯·热尔马诺和迪达；后卫阿尔代尔、贡萨尔维斯、儒尼奥尔·拜伊亚诺和马尔西奥·桑托斯、卡福、罗伯特·卡洛斯和泽·罗伯特；中场队员塞萨尔·桑巴约、德尼尔森、多里瓦、邓加、弗拉维奥、乔瓦尼、莱奥纳多和里瓦尔多；前锋贝贝托、埃德蒙多、罗纳尔多和罗马里奥。[SEP]\n",
      "[CLS]7······塔1法2雷2尔2·卡1洛2斯2·2热2尔2马2诺2·迪1达2···阿1尔2代2尔2·贡1萨2尔2维2斯2·儒1尼2奥2尔2·2拜2伊2亚2诺2·马1尔2西2奥2·2桑2托2斯2·卡1福2·罗1伯2特2·2卡2洛2斯2·泽1·2罗2伯2特2·····塞1萨2尔2·2桑2巴2约2·德1尼2尔2森2·多1里2瓦2·邓1加2·弗1拉2维2奥2·乔1瓦2尼2·莱1奥2纳2多2·里1瓦2尔2多2···贝1贝2托2·埃1德2蒙2多2·罗1纳2尔2多2·罗1马2里2奥2·[SEP]7\n",
      "[CLS]7······塔1法2雷2尔2·卡1洛2斯2·2热2尔2马2诺2·迪1达2···阿1尔2代2尔2·贡1萨2尔2维2斯2·儒1尼2奥2尔2·2拜2伊2亚2诺2·马1尔2西2奥2·2桑2托2斯2·卡1福2·罗1伯2特2·2卡2洛2斯2·泽1·2罗2伯2特2·····塞1萨2尔2·2桑2巴2约2·德1尼2尔2森2·多1里2瓦2·邓1加2·弗1拉2维2奥2·乔1瓦2尼2·莱1奥2纳2多2·里1瓦2尔2多2···贝1贝2托2·埃1德2蒙2多2·罗1纳2尔2多2·罗1马2里2奥2·[SEP]7\n",
      "==========================\n",
      "[CLS][UNK]用户永远是对的[UNK]是海尔的售后服务观。[SEP]\n",
      "[CLS]7··········海3尔4·······[SEP]7\n",
      "[CLS]7··········海3尔4·······[SEP]7\n",
      "==========================\n",
      "[CLS]而仍像个大男孩似的他，还是老想法：荣耀再多，皆成往事；只有把舞步跳向崎岖的未来，才最为诱人[UNK][UNK][SEP]\n",
      "[CLS]7···············································[SEP]7\n",
      "[CLS]7···············································[SEP]7\n",
      "==========================\n",
      "[CLS]从接轨角度看，这场风暴是经济全球化进程中区域金融机制的强行矫正；从当事人角度看，这是一场国际财富的强行再分配；从被冲击方看，这是增长奇迹的暂时中断、经济泡沫的迅速释放。[SEP]\n",
      "[CLS]7····················································································[SEP]7\n",
      "[CLS]7····················································································[SEP]7\n",
      "==========================\n",
      "[CLS]龙永图说，尽管本地区受到了金融危机的严重冲击，中国的经贸发展也受到了不小的负面影响，但中国改革开放的步伐不会停止。[SEP]\n",
      "[CLS]7龙1永2图2····················中5国6··················中5国6············[SEP]7\n",
      "[CLS]7龙1永2图2····················中5国6··················中5国6············[SEP]7\n",
      "==========================\n",
      "[CLS]奥运会冠军占旭刚、世界冠军崔文华以及刚刚获得世界杯总决赛冠军的兰世章等10名运动员，依次登台向大家介绍各自家乡发生的巨大变化。[SEP]\n",
      "[CLS]7奥5····占1旭2刚2·····崔1文2华2···············兰1世2章2·····························[SEP]7\n",
      "[CLS]7奥5····占1旭2刚2·····崔1文2华2···············兰1世2章2·····························[SEP]7\n",
      "==========================\n",
      "[CLS]乡局级后备干部队伍人数达五百三十人，比上年同期增长百分之四十一点七。[SEP]\n",
      "[CLS]7··································[SEP]7\n",
      "[CLS]7··································[SEP]7\n",
      "==========================\n",
      "[CLS]其三，协议能否有效执行尚不确定。[SEP]\n",
      "[CLS]7················[SEP]7\n",
      "[CLS]7················[SEP]7\n",
      "==========================\n"
     ]
    }
   ],
   "source": [
    "#第11章/预测\n",
    "def predict():\n",
    "    #加载训练完的模型参数\n",
    "    model.load_weights('model/tf_parameters/中文命名实体识别')\n",
    "\n",
    "    #测试数据集\n",
    "    dataset_test = get_dataset('test')\n",
    "\n",
    "    #取一个批次的数据\n",
    "    inputs, labels = get_batch_data(dataset_test, 0, 32)\n",
    "\n",
    "    #计算\n",
    "    #[b, lens] -> [b, lens, 8] -> [b, lens]\n",
    "    outs = model(inputs)\n",
    "    outs = tf.argmax(outs, axis=2, output_type=tf.int32)\n",
    "\n",
    "    for i in range(32):\n",
    "        #移除pad\n",
    "        select = inputs['attention_mask'][i] == 1\n",
    "\n",
    "        input_id = tf.boolean_mask(inputs['input_ids'][i], axis=0, mask=select)\n",
    "        out = tf.boolean_mask(outs[i], axis=0, mask=select)\n",
    "        label = tf.boolean_mask(labels[i], axis=0, mask=select)\n",
    "\n",
    "        #输出原句子\n",
    "        print(tokenizer.decode(input_id).replace(' ', ''))\n",
    "\n",
    "        #输出tag\n",
    "        for tag in [label, out]:\n",
    "            s = ''\n",
    "            for j in range(len(tag)):\n",
    "                if tag[j] == 0:\n",
    "                    s += '·'\n",
    "                    continue\n",
    "                s += tokenizer.decode(input_id[j])\n",
    "                s += str(int(tag[j]))\n",
    "\n",
    "            print(s)\n",
    "        print('==========================')\n",
    "\n",
    "\n",
    "predict()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
